import numpy as np  # NumPy是一个用于科学计算的基础包,用于处理大型多维数组和矩阵
import faiss  # FAISS库用于高效的相似度搜索和稠密向量的聚类

# 定义FaissKNeighbors类，用于执行基于FAISS的K近邻搜索
class FaissKNeighbors:
    # 类初始化函数：初始化k值，FAISS资源对象res，以及用于存储数据的索引
    def __init__(self, k=1, res=None):
        self.index = None  # 用于存储训练数据的索引
        self.y = None  # 用于存储训练数据的标签
        self.k = k  # 最近邻个数
        self.res = res  # FAISS GPU资源对象

    # 训练函数：将训练数据加入到FAISS索引中
    def fit(self, X, y):
        # 创建一个基于 L2 距离的平面索引
        d = X.shape[1]
        cpu_index = faiss.IndexFlatL2(d)

        # 如果提供了 GPU 资源，则将索引复制到 GPU
        if self.res is not None:
            try:
                # 将 CPU 索引迁移到 GPU
                gpu_index = faiss.index_cpu_to_gpu(self.res, 0, cpu_index)
                self.index = gpu_index
            except Exception:
                # 如果转移失败，回退到 CPU 索引
                self.index = cpu_index
        else:
            self.index = cpu_index

        # 添加向量到索引 (确保为 float32)
        self.index.add(X.astype(np.float32))

        # 保存标签数组
        self.y = np.array(y)

    # 预测函数：对新的数据集X进行分类预测
    def predict(self, X):
        # 搜索X中每个向量的k个最近邻
        distances, indices = self.index.search(X.astype(np.float32), self.k)
        votes = self.y[indices]  # 根据索引获得最近邻的标签
        # 通过投票机制得出最终预测的标签
        predictions = np.array([np.argmax(np.bincount(vote)) for vote in votes])
        return predictions

    # 评分函数：计算预测准确率
    def score(self, X, y):
        preds = self.predict(X)
        y = np.array(y)
        correct = np.sum(preds == y)
        return float(correct) / len(y) if len(y) > 0 else 0.0
